from typing import Union

import cv2

# Public API of this module: the names exported by a wildcard import
# (`from <module> import *`). Entries are listed in lexicographic (ASCII)
# order, so the upper-case constants come before the lower-case functions.
__all__ = [
    "INPUT_IMAGE_ID_KEY",
    "INPUT_IMAGE_KEY",
    "INPUT_INDEX_KEY",
    "OUTPUT_EMBEDDINGS_KEY",
    "OUTPUT_LOGITS_KEY",
    "OUTPUT_MASK_KEY",
    "OUTPUT_MASK_KEY_STRIDE_16",
    "OUTPUT_MASK_KEY_STRIDE_2",
    "OUTPUT_MASK_KEY_STRIDE_32",
    "OUTPUT_MASK_KEY_STRIDE_4",
    "OUTPUT_MASK_KEY_STRIDE_64",
    "OUTPUT_MASK_KEY_STRIDE_8",
    "TARGET_CLASS_KEY",
    "TARGET_LABELS_KEY",
    "TARGET_MASK_KEY",
    "TARGET_MASK_KEY_STRIDE_16",
    "TARGET_MASK_KEY_STRIDE_2",
    "TARGET_MASK_KEY_STRIDE_32",
    "TARGET_MASK_KEY_STRIDE_4",
    "TARGET_MASK_KEY_STRIDE_64",
    "TARGET_MASK_KEY_STRIDE_8",
    "TARGET_MASK_WEIGHT_KEY",
    "name_for_stride",
    "read_image_rgb",
]


def name_for_stride(name, stride: Union[int, None]) -> str:
    """Build the stride-specific variant of a key name.

    Args:
        name: Base key name.
        stride: Stride to encode in the name, or ``None`` for no suffix.

    Returns:
        ``name`` itself when ``stride`` is ``None``; otherwise the string
        ``"<name>_STRIDE_<stride>"``.
    """
    # Compare against None explicitly so that a stride of 0 still gets a suffix.
    return name if stride is None else f"{name}_STRIDE_{stride}"


# Every constant in this section is a string whose value equals its own name.

# --- Input keys ---
INPUT_INDEX_KEY = "INPUT_INDEX_KEY"
INPUT_IMAGE_KEY = "INPUT_IMAGE_KEY"
INPUT_IMAGE_ID_KEY = "INPUT_IMAGE_ID_KEY"

# --- Target keys ---
TARGET_MASK_WEIGHT_KEY = "TARGET_MASK_WEIGHT_KEY"
TARGET_CLASS_KEY = "TARGET_CLASS_KEY"
TARGET_LABELS_KEY = "TARGET_LABELS_KEY"

TARGET_MASK_KEY = "TARGET_MASK_KEY"

# Per-stride variants of the target mask key, e.g. "TARGET_MASK_KEY_STRIDE_2".
TARGET_MASK_KEY_STRIDE_2 = name_for_stride(TARGET_MASK_KEY, stride=2)
TARGET_MASK_KEY_STRIDE_4 = name_for_stride(TARGET_MASK_KEY, stride=4)
TARGET_MASK_KEY_STRIDE_8 = name_for_stride(TARGET_MASK_KEY, stride=8)
TARGET_MASK_KEY_STRIDE_16 = name_for_stride(TARGET_MASK_KEY, stride=16)
TARGET_MASK_KEY_STRIDE_32 = name_for_stride(TARGET_MASK_KEY, stride=32)
TARGET_MASK_KEY_STRIDE_64 = name_for_stride(TARGET_MASK_KEY, stride=64)

# --- Output keys ---
OUTPUT_MASK_KEY = "OUTPUT_MASK_KEY"
# Per-stride variants of the output mask key, e.g. "OUTPUT_MASK_KEY_STRIDE_2".
OUTPUT_MASK_KEY_STRIDE_2 = name_for_stride(OUTPUT_MASK_KEY, stride=2)
OUTPUT_MASK_KEY_STRIDE_4 = name_for_stride(OUTPUT_MASK_KEY, stride=4)
OUTPUT_MASK_KEY_STRIDE_8 = name_for_stride(OUTPUT_MASK_KEY, stride=8)
OUTPUT_MASK_KEY_STRIDE_16 = name_for_stride(OUTPUT_MASK_KEY, stride=16)
OUTPUT_MASK_KEY_STRIDE_32 = name_for_stride(OUTPUT_MASK_KEY, stride=32)
OUTPUT_MASK_KEY_STRIDE_64 = name_for_stride(OUTPUT_MASK_KEY, stride=64)

OUTPUT_LOGITS_KEY = "OUTPUT_LOGITS_KEY"
OUTPUT_EMBEDDINGS_KEY = "OUTPUT_EMBEDDINGS_KEY"


def read_image_rgb(fname: str):
    """Read an image from disk and return it with channels in RGB order.

    Args:
        fname: Path of the image file to read.

    Returns:
        The array loaded by ``cv2.imread`` with its last axis reversed, which
        turns OpenCV's default BGR channel order into RGB. The result is a
        reversed view of the loaded array, not a copy.

    Raises:
        IOError: If ``cv2.imread`` could not read the file (it signals
            failure by returning ``None`` rather than raising).
    """
    image = cv2.imread(fname)
    # Check for failure BEFORE indexing: subscripting None would raise a
    # TypeError and the intended IOError below would never be reached.
    if image is None:
        raise IOError(f"Cannot read {fname}")
    return image[..., ::-1]
